package pattern.gateway.filter;

import cn.hutool.core.date.StopWatch;
import cn.hutool.core.lang.UUID;
import io.netty.buffer.UnpooledByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import pattern.gateway.entity.Log;
import pattern.utils.ServletTool;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

/**
 * Global gateway filter that writes one {@link Log} entry per incoming request.
 *
 * <p>For every request the filter:
 * <ul>
 *   <li>reuses the caller-supplied {@code request_id} header, or generates one when it is blank;</li>
 *   <li>adds the {@code request_id} (when generated) and {@code start_time_key} headers to the
 *       request that is forwarded down the filter chain;</li>
 *   <li>captures the URL, method, client IP, session id, token and a description of the body;</li>
 *   <li>for non-GET, non-upload requests with a positive {@code Content-Length}, reads the body as
 *       a string and replays it to the rest of the chain through a request decorator.</li>
 * </ul>
 *
 * @author Simon
 */
@Slf4j
@Component
public class RequestLogGlobalFilter implements GlobalFilter, Ordered {

    /** Header carrying the request correlation id. */
    private static final String REQUEST_ID = "request_id";
    /** Header carrying the epoch-millisecond timestamp at which this filter saw the request. */
    private static final String START_TIME_KEY = "start_time_key";
    /** Value logged in place of the body for file uploads, whose content is never read. */
    private static final String UPLOAD_BODY_PLACEHOLDER = "上传文件";

    @Autowired
    ServerCodecConfigurer codecConfigurer;

    /**
     * Logs the request and forwards it (with the tracing headers added) to the rest of the chain.
     *
     * <p>The log entry is written just before the downstream chain is invoked, so a request is
     * logged even when a later filter fails. Any exception thrown while the pipeline is being
     * assembled is logged and the original, unmodified exchange is forwarded instead.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return a {@code Mono} that completes when the downstream chain completes
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        long startTime = System.currentTimeMillis();
        try {
            ServerHttpRequest request = exchange.getRequest();
            HttpHeaders headers = request.getHeaders();

            // Resolve the request id up front so the log entry and the forwarded header agree.
            String incomingRequestId = headers.getFirst(REQUEST_ID);
            boolean generateRequestId = StringUtils.isBlank(incomingRequestId);
            String requestId = generateRequestId ? UUID.fastUUID().toString(true) : incomingRequestId;

            // Request headers are read-only; mutate() yields a copy carrying the extra headers.
            ServerHttpRequest forwardedRequest = request.mutate().headers(httpHeaders -> {
                if (generateRequestId) {
                    httpHeaders.set(REQUEST_ID, requestId);
                }
                httpHeaders.set(START_TIME_KEY, String.valueOf(startTime));
            }).build();
            ServerWebExchange forwardedExchange = exchange.mutate().request(forwardedRequest).build();

            ServerRequest serverRequest = ServerRequest.create(forwardedExchange, codecConfigurer.getReaders());

            URI requestUri = request.getURI();
            String uriQuery = requestUri.getQuery();
            String url = StringUtils.isNotBlank(uriQuery)
                    ? requestUri.getPath() + "?" + uriQuery
                    : requestUri.getPath();
            MediaType mediaType = headers.getContentType();
            String method = request.getMethodValue().toUpperCase(Locale.ROOT);

            // Decide what to log as the body without reading it, and whether it must be read.
            String loggedBody = null;
            boolean bodyNeedsReading = false;
            if (mediaType != null && ServletTool.isUploadFile(mediaType)) {
                loggedBody = UPLOAD_BODY_PLACEHOLDER;
            } else if (HttpMethod.GET.name().equals(method)) {
                if (StringUtils.isNotBlank(uriQuery)) {
                    loggedBody = uriQuery;
                }
            } else {
                bodyNeedsReading = true;
            }
            final boolean captureBody = bodyNeedsReading && headers.getContentLength() > 0;

            // Re-encode the replayed body with the charset declared by the client (UTF-8 when none
            // is declared) instead of the platform default, so the bytes survive the round trip.
            Charset declaredCharset = mediaType != null ? mediaType.getCharset() : null;
            final Charset bodyCharset = declaredCharset != null ? declaredCharset : StandardCharsets.UTF_8;

            final Log logDTO = new Log();
            logDTO.setRequestUrl(url);
            logDTO.setRequestBody(loggedBody);
            logDTO.setRequestMethod(method);
            logDTO.setRequestId(requestId);
            logDTO.setIp(ServletTool.getIpAddr(request));

            DataBufferFactory bufferFactory = forwardedExchange.getResponse().bufferFactory();

            return forwardedExchange.getSession().flatMap(webSession -> {
                logDTO.setSessionId(webSession.getId());
                // NOTE(review): the token ends up in the log output via Log#toString — confirm
                // that this is acceptable or mask it inside Log.
                logDTO.setToken(ServletTool.getToken(request));

                if (!captureBody) {
                    return doRecord(logDTO).then(chain.filter(forwardedExchange));
                }

                return serverRequest.bodyToMono(String.class)
                        // Without a default an empty body would skip flatMap and the chain
                        // would never be invoked.
                        .defaultIfEmpty("")
                        .flatMap(reqBody -> {
                            logDTO.setRequestBody(reqBody);
                            byte[] bodyBytes = reqBody.getBytes(bodyCharset);
                            // The body has been consumed above; replay it for downstream readers.
                            ServerHttpRequestDecorator replayableRequest =
                                    new ServerHttpRequestDecorator(forwardedRequest) {
                                        @Override
                                        public Flux<DataBuffer> getBody() {
                                            // A fresh buffer per subscription.
                                            return Flux.defer(() -> Flux.just(bufferFactory.wrap(bodyBytes)));
                                        }
                                    };
                            ServerWebExchange replayableExchange =
                                    forwardedExchange.mutate().request(replayableRequest).build();
                            return doRecord(logDTO).then(chain.filter(replayableExchange));
                        });
            });

        } catch (Exception e) {
            // Logging is best-effort: never block the request because of it.
            log.error("请求日志打印出现异常", e);
            return chain.filter(exchange);
        }
    }

    /**
     * Writes the given entry to the log at INFO level.
     *
     * <p>The entry is logged immediately when this method is called, not when the returned
     * {@code Mono} is subscribed.
     *
     * @param dto the log entry; a {@code null} value is logged as {@code "null"}
     * @return an empty {@code Mono}
     */
    public static Mono<Void> doRecord(Log dto) {
        log.info("{}", dto);
        return Mono.empty();
    }

    /**
     * @return {@link Ordered#HIGHEST_PRECEDENCE}, the highest precedence value for this filter
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}
